#!/usr/bin/env python
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import argparse
import sys
from builtins import range

import numpy as np
import tritonclient.grpc as grpcclient
import tritonclient.http as httpclient
import tritonclient.utils as utils
import tritonclient.utils.shared_memory as shm

FLAGS = None


def _read_output(results, name, use_shared_memory, shm_handle):
    """Return the output tensor called ``name`` from ``results`` as a numpy array.

    When ``use_shared_memory`` is true the tensor contents are read back
    through ``shm_handle`` using the datatype and shape reported in the
    response; otherwise they are taken from the response itself.

    Prints a message and exits the process with status 1 when the response
    does not contain the requested output.
    """
    output = results.get_output(name)
    if output is None:
        print(name + " is missing in the response.")
        sys.exit(1)

    if not use_shared_memory:
        return results.as_numpy(name)

    # The gRPC result exposes datatype/shape as attributes, while the HTTP
    # result exposes them as dictionary keys.
    if protocol == "grpc":
        datatype, shape = output.datatype, output.shape
    else:
        datatype, shape = output["datatype"], output["shape"]

    return shm.get_contents_as_numpy(
        shm_handle, utils.triton_to_np_dtype(datatype), shape
    )


def infer_and_validata(use_shared_memory, orig_input0_data, orig_input1_data):
    """Run one inference with the shared request objects and check the results.

    The function relies on module-level names assigned by the script entry
    point: ``inputs``, ``outputs``, ``triton_client``, ``model_name``,
    ``protocol``, ``shm_op0_handle`` and ``shm_op1_handle``.

    Args:
        use_shared_memory: When true, inputs and outputs are bound to the
            "input0_data"/"input1_data"/"output0_data"/"output1_data"
            shared memory regions. When false, the inputs are sent from
            numpy arrays and the outputs' shared memory binding is removed.
        orig_input0_data: 1-D numpy array used as INPUT0.
        orig_input1_data: 1-D numpy array used as INPUT1; it is doubled
            before being sent on the non-shared-memory path.

    Exits the process with status 1 if an output is missing, if OUTPUT0 is
    not the element-wise sum, or if OUTPUT1 is not the element-wise
    difference of the inputs.
    """
    input0_data = orig_input0_data
    if use_shared_memory:
        input1_data = orig_input1_data
        byte_size = input0_data.size * input0_data.itemsize
        inputs[0].set_shared_memory("input0_data", byte_size)
        inputs[1].set_shared_memory("input1_data", byte_size)
        outputs[0].set_shared_memory("output0_data", byte_size)
        outputs[1].set_shared_memory("output1_data", byte_size)
    else:
        # NOTE(review): INPUT1 is doubled here, presumably so the values
        # sent in the request differ from the ones stored in the shared
        # memory regions - confirm against the test's intent.
        input1_data = orig_input1_data * 2
        inputs[0].set_data_from_numpy(np.expand_dims(input0_data, axis=0))
        inputs[1].set_data_from_numpy(np.expand_dims(input1_data, axis=0))
        outputs[0].unset_shared_memory()
        outputs[1].unset_shared_memory()

    results = triton_client.infer(model_name=model_name, inputs=inputs, outputs=outputs)

    # Fetch both outputs, either from shared memory or from the response.
    output0_data = _read_output(results, "OUTPUT0", use_shared_memory, shm_op0_handle)
    output1_data = _read_output(results, "OUTPUT1", use_shared_memory, shm_op1_handle)

    if use_shared_memory:
        print("\n\n======== SHARED_MEMORY ========\n")
    else:
        print("\n\n======== NO_SHARED_MEMORY ========\n")

    # Validate every element that was sent instead of assuming a fixed
    # tensor length. The outputs are indexed as [0][i], i.e. row 0 of a
    # result with a leading axis.
    for i, (lhs, rhs) in enumerate(zip(input0_data, input1_data)):
        total = output0_data[0][i]
        difference = output1_data[0][i]
        print(str(lhs) + " + " + str(rhs) + " = " + str(total))
        print(str(lhs) + " - " + str(rhs) + " = " + str(difference))
        if (lhs + rhs) != total:
            print("shm infer error: incorrect sum")
            sys.exit(1)
        if (lhs - rhs) != difference:
            print("shm infer error: incorrect difference")
            sys.exit(1)
    print("\n======== END ========\n\n")


# Tests whether the same InferInput and InferRequestedOutput objects can be
# successfully used repeatedly for different inferences using/not-using
# shared memory.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        required=False,
        default=False,
        help="Enable verbose output",
    )
    parser.add_argument(
        "-i",
        "--protocol",
        type=str,
        required=False,
        default="HTTP",
        help="Protocol (HTTP/gRPC) used to communicate with "
        + "the inference service. Default is HTTP.",
    )
    parser.add_argument(
        "-u",
        "--url",
        type=str,
        required=False,
        default="localhost:8000",
        help="Inference server URL. Default is localhost:8000.",
    )

    FLAGS = parser.parse_args()

    # The names assigned below are module globals. infer_and_validata reads
    # protocol, triton_client, model_name, inputs, outputs and the two
    # output shared memory handles, so those names must be kept.
    protocol = FLAGS.protocol.lower()

    # Any protocol value other than "grpc" selects the HTTP client.
    client_module = grpcclient if protocol == "grpc" else httpclient

    try:
        triton_client = client_module.InferenceServerClient(
            url=FLAGS.url, verbose=FLAGS.verbose
        )
    except Exception as e:
        print("client creation failed: " + str(e))
        sys.exit(1)

    # To make sure no shared memory regions are registered with the
    # server.
    triton_client.unregister_system_shared_memory()
    triton_client.unregister_cuda_shared_memory()

    # We use a simple model that takes 2 input tensors of 16 integers
    # each and returns 2 output tensors of 16 integers each. One
    # output tensor is the element-wise sum of the inputs and one
    # output is the element-wise difference.
    model_name = "simple"
    model_version = ""

    # Create the data for the two input tensors. Initialize the first
    # to unique integers and the second to all ones.
    input0_data = np.arange(start=0, stop=16, dtype=np.int32)
    input1_data = np.ones(shape=16, dtype=np.int32)

    input_byte_size = input0_data.size * input0_data.itemsize
    output_byte_size = input_byte_size

    # Handles of the regions created so far. They are released in the
    # finally clause so that a failed registration or inference (including
    # sys.exit from infer_and_validata) does not leave regions behind.
    created_handles = []
    try:
        # Create Output0 and Output1 in Shared Memory and store shared
        # memory handles
        shm_op0_handle = shm.create_shared_memory_region(
            "output0_data", "/output0_simple", output_byte_size
        )
        created_handles.append(shm_op0_handle)
        shm_op1_handle = shm.create_shared_memory_region(
            "output1_data", "/output1_simple", output_byte_size
        )
        created_handles.append(shm_op1_handle)

        # Register Output0 and Output1 shared memory with Triton Server
        triton_client.register_system_shared_memory(
            "output0_data", "/output0_simple", output_byte_size
        )
        triton_client.register_system_shared_memory(
            "output1_data", "/output1_simple", output_byte_size
        )

        # Create Input0 and Input1 in Shared Memory and store shared
        # memory handles
        shm_ip0_handle = shm.create_shared_memory_region(
            "input0_data", "/input0_simple", input_byte_size
        )
        created_handles.append(shm_ip0_handle)
        shm_ip1_handle = shm.create_shared_memory_region(
            "input1_data", "/input1_simple", input_byte_size
        )
        created_handles.append(shm_ip1_handle)

        # Put input data values into shared memory
        shm.set_shared_memory_region(shm_ip0_handle, [input0_data])
        shm.set_shared_memory_region(shm_ip1_handle, [input1_data])

        # Register Input0 and Input1 shared memory with Triton Server
        triton_client.register_system_shared_memory(
            "input0_data", "/input0_simple", input_byte_size
        )
        triton_client.register_system_shared_memory(
            "input1_data", "/input1_simple", input_byte_size
        )

        # Request objects that are reused by every inference below
        inputs = [
            client_module.InferInput("INPUT0", [1, 16], "INT32"),
            client_module.InferInput("INPUT1", [1, 16], "INT32"),
        ]
        outputs = [
            client_module.InferRequestedOutput("OUTPUT0"),
            client_module.InferRequestedOutput("OUTPUT1"),
        ]

        # Use shared memory
        infer_and_validata(True, input0_data, input1_data)

        # Don't use shared memory
        infer_and_validata(False, input0_data, input1_data)

        # Use shared memory
        infer_and_validata(True, input0_data, input1_data)
    finally:
        # Unregister from the server first, then destroy the local regions.
        # The nested finally destroys the regions even if unregistering
        # raises.
        try:
            triton_client.unregister_system_shared_memory()
        finally:
            for handle in created_handles:
                shm.destroy_shared_memory_region(handle)
